{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7ede42eb-0cd2-4fa7-9559-302d37778e39",
   "metadata": {},
   "source": [
    "# jaxKAN Results"
   ]
  },
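  {
   "cell_type": "markdown",
   "id": "86555810-1110-4127-9002-368446f8d163",
   "metadata": {},
   "source": [
    "Preliminaries: import the required libraries and add the local `src` directory (which provides the `KAN` and `PIKAN` modules) to the module search path."
   ]
  },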
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1829191f-7f46-4575-ba94-2f09d5e6fa5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import optax\n",
    "import flax\n",
    "from flax import linen as nn\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import time\n",
    "\n",
    "# Add /src to path\n",
    "path_to_src = os.path.abspath(os.path.join(os.getcwd(), '../../../src'))\n",
    "if path_to_src not in sys.path:\n",
    "    sys.path.append(path_to_src)\n",
    "\n",
    "from KAN import KAN\n",
    "from PIKAN import *\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c246f07-0ef3-498c-a341-dd69dc5a7152",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Diffusion Equation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a901b2fb-52c7-45b4-be6f-f7d7e98d158d",
   "metadata": {},
   "source": [
    "### Collocation Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43a706a4-3906-4441-a6e1-24f9f5001f9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE-residual collocation points; column 0 is t and column 1 is x\n",
    "N = 2**12\n",
    "collocs = jnp.array(sobol_sample(np.array([0, 0]), np.array([1, 1]), N))  # (4096, 2)\n",
    "\n",
    "# Points and target values for the initial/boundary conditions\n",
    "N = 2**6\n",
    "\n",
    "# t = 0: u = sin(pi*x)\n",
    "BC1_colloc = jnp.array(sobol_sample(np.array([0, 0]), np.array([0, 1]), N))  # (64, 2)\n",
    "BC1_data = jnp.sin(np.pi * BC1_colloc[:, [1]])  # (64, 1)\n",
    "\n",
    "# x = 0: u = 0\n",
    "BC2_colloc = jnp.array(sobol_sample(np.array([0, 0]), np.array([1, 0]), N))  # (64, 2)\n",
    "BC2_data = jnp.zeros((BC2_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# x = 1: u = 0\n",
    "BC3_colloc = jnp.array(sobol_sample(np.array([0, 1]), np.array([1, 1]), N))  # (64, 2)\n",
    "BC3_data = jnp.zeros((BC3_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# Same ordering in both lists so points and targets stay paired\n",
    "bc_collocs = [BC1_colloc, BC2_colloc, BC3_colloc]\n",
    "bc_data = [BC1_data, BC2_data, BC3_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e85ac1ab-4cff-4c09-884e-6dd4d22e2bca",
   "metadata": {},
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c434313-ee08-4f38-9dbe-cf7bb1af6f93",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_loss(params, collocs, state):\n",
    "    \"\"\"Pointwise residual of the forced diffusion equation u_t - D*u_xx = f.\n",
    "\n",
    "    Args:\n",
    "        params: trainable parameters, passed to model.apply under 'params'.\n",
    "        collocs: collocation points; column 0 is t and column 1 is x.\n",
    "        state: non-trainable collection, passed to model.apply under 'state'.\n",
    "\n",
    "    Returns:\n",
    "        The PDE residual evaluated at every row of collocs.\n",
    "    \"\"\"\n",
    "    # Diffusion coefficient\n",
    "    D = jnp.array(1.0, dtype=float)\n",
    "\n",
    "    # Define the model function\n",
    "    variables = {'params' : params, 'state' : state}\n",
    "\n",
    "    def u(vec_x):\n",
    "        # model.apply returns a pair; only the first element (the prediction) is used\n",
    "        y, spl = model.apply(variables, vec_x)\n",
    "        return y\n",
    "\n",
    "    # Physics Loss Terms\n",
    "    u_t = gradf(u, 0, 1)  # 1st order derivative of t\n",
    "    u_xx = gradf(u, 1, 2) # 2nd order derivative of x\n",
    "\n",
    "    # Forcing term f = u_t - D*u_xx for the manufactured solution u = sin(pi*x)*exp(-t).\n",
    "    # The pi**2 contribution comes from -D*u_xx, so it is scaled by D to keep the\n",
    "    # source consistent with the residual if D is ever changed from 1.\n",
    "    decay = jnp.exp(-collocs[:,[0]])\n",
    "    sin_x = jnp.sin(jnp.pi*collocs[:,[1]])\n",
    "    source = decay*(-sin_x + D*(jnp.pi**2)*sin_x)\n",
    "\n",
    "    # Residual\n",
    "    pde_res = u_t(collocs) - D*u_xx(collocs) - source\n",
    "\n",
    "    return pde_res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b23c625a-4617-46fb-a1ba-6c7c5ad4e4e1",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "460f30c0-463a-4e7d-a466-c39135e7e174",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed PRNG seed for parameter initialisation\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "# Build the KAN and initialise its variables with a dummy (1, 2) input\n",
    "layer_dims = [2, 6, 6, 1]\n",
    "model = KAN(\n",
    "    layer_dims=layer_dims,\n",
    "    k=3,\n",
    "    const_spl=False,\n",
    "    const_res=False,\n",
    "    add_bias=True,\n",
    "    grid_e=0.05,\n",
    ")\n",
    "variables = model.init(key, jnp.ones([1, 2]))\n",
    "\n",
    "# Learning-rate settings handed to the trainer's scheduler\n",
    "lr_vals = {'init_lr': 0.001, 'scales': {0: 1.0}}\n",
    "\n",
    "# Global loss weights: four unit entries (presumably PDE residual + one per boundary condition)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a8f026e-18d3-4a3e-b31b-03c4d39f1a4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 50000\n",
    "\n",
    "# Train with the settings defined above; the grid- and collocation-adaptation schedules are empty\n",
    "model, variables, train_losses = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=False,\n",
    "    loc_w=None,\n",
    "    nesterov=False,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend={0 : 3},\n",
    "    grid_adapt=[],\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5b70a6-860e-4a73-80d7-f79b99b9c231",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "2607816d-1611-4267-a776-a1959d825a56",
   "metadata": {},
   "source": [
    "### Plot & Save Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d538259-5415-4891-9097-87d8c998205b",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_t, N_x = 100, 256\n",
    "\n",
    "# Evaluation grid: first axis follows t, second axis follows x\n",
    "t = np.linspace(0.0, 1.0, N_t)\n",
    "x = np.linspace(0.0, 1.0, N_x)\n",
    "T, X = np.meshgrid(t, x, indexing='ij')\n",
    "coords = np.column_stack((T.ravel(), X.ravel()))\n",
    "\n",
    "# Evaluate the trained model on the grid and restore the 2-D layout\n",
    "output, _ = model.apply(variables, jnp.array(coords))\n",
    "resplot = np.array(output).reshape(N_t, N_x)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "mesh = ax.pcolormesh(T, X, resplot, shading='auto', cmap='Spectral_r')\n",
    "fig.colorbar(mesh, ax=ax)\n",
    "\n",
    "ax.set_title('Solution of Diffusion Equation')\n",
    "ax.set_xlabel('t')\n",
    "ax.set_ylabel('x')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f1e6ef-6bc9-44ff-bbe7-989ea3ba4e79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss history on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(np.array(train_losses), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e236db8-ff7c-4356-94e6-3ce00e1590de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def diff_exact(t, x):\n",
    "    \"\"\"Analytical solution u(t, x) = sin(pi*x) * exp(-t) used as the reference.\"\"\"\n",
    "    return np.sin(np.pi*x)*np.exp(-t)\n",
    "\n",
    "# Reference solution on the same grid as the model prediction\n",
    "ref_result = diff_exact(T, X)\n",
    "\n",
    "# Serialise the trained variables (raw bytes from flax.serialization.to_bytes)\n",
    "serialized_variables = flax.serialization.to_bytes(variables)\n",
    "\n",
    "# Create the output folders if needed so a finished training run is not lost\n",
    "# to a FileNotFoundError on a fresh checkout\n",
    "os.makedirs('jaxKAN models', exist_ok=True)\n",
    "os.makedirs('../Plots/data', exist_ok=True)\n",
    "\n",
    "with open('jaxKAN models/eq1-jaxkan.pkl', 'wb') as f:\n",
    "    f.write(serialized_variables)\n",
    "\n",
    "np.savez('../Plots/data/eq1-jaxkan.npz', t=t, x=x, result=resplot, ref=ref_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "498f7551-91e2-4953-b0ac-eb15bed218e6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "397dada5-ff95-43b2-9cde-cb5e101f1877",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Helmholtz Equation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ec48bb3-9e86-4ddf-89b6-4a94dca0fc7f",
   "metadata": {},
   "source": [
    "### Collocation Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "803f673f-f4b6-4798-8b65-425bcc90609e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE-residual collocation points; column 0 is x and column 1 is y\n",
    "N = 2**12\n",
    "collocs = jnp.array(sobol_sample(np.array([-1, -1]), np.array([1, 1]), N))  # (4096, 2)\n",
    "\n",
    "# Points and (zero) target values on the four edges of the square\n",
    "N = 2**6\n",
    "\n",
    "# x = -1\n",
    "BC1_colloc = jnp.array(sobol_sample(np.array([-1, -1]), np.array([-1, 1]), N))  # (64, 2)\n",
    "BC1_data = jnp.zeros((BC1_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# y = -1\n",
    "BC2_colloc = jnp.array(sobol_sample(np.array([-1, -1]), np.array([1, -1]), N))  # (64, 2)\n",
    "BC2_data = jnp.zeros((BC2_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# x = 1\n",
    "BC3_colloc = jnp.array(sobol_sample(np.array([1, -1]), np.array([1, 1]), N))  # (64, 2)\n",
    "BC3_data = jnp.zeros((BC3_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# y = 1\n",
    "BC4_colloc = jnp.array(sobol_sample(np.array([-1, 1]), np.array([1, 1]), N))  # (64, 2)\n",
    "BC4_data = jnp.zeros((BC4_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# Same ordering in both lists so points and targets stay paired\n",
    "bc_collocs = [BC1_colloc, BC2_colloc, BC3_colloc, BC4_colloc]\n",
    "bc_data = [BC1_data, BC2_data, BC3_data, BC4_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17ef76f8-e3e4-4298-bfc2-e51710a13922",
   "metadata": {},
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cef0e437-bfa0-49ce-b0b4-84e94d5295f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_loss(params, collocs, state):\n",
    "    \"\"\"Pointwise residual of the Helmholtz equation u_xx + u_yy + k^2*u = q.\n",
    "\n",
    "    Args:\n",
    "        params: trainable parameters, passed to model.apply under 'params'.\n",
    "        collocs: collocation points; column 0 is x and column 1 is y.\n",
    "        state: non-trainable collection, passed to model.apply under 'state'.\n",
    "\n",
    "    Returns:\n",
    "        The PDE residual evaluated at every row of collocs.\n",
    "    \"\"\"\n",
    "    # Eq. parameters\n",
    "    k = jnp.array(1.0, dtype=float)\n",
    "    a1 = jnp.array(1.0, dtype=float)\n",
    "    a2 = jnp.array(4.0, dtype=float)\n",
    "\n",
    "    # Define the model function\n",
    "    variables = {'params' : params, 'state' : state}\n",
    "\n",
    "    def u(vec_x):\n",
    "        # model.apply returns a pair; only the first element (the prediction) is used\n",
    "        y, spl = model.apply(variables, vec_x)\n",
    "        return y\n",
    "\n",
    "    # Physics Loss Terms\n",
    "    u_xx = gradf(u, 0, 2)  # 2nd order derivative of x\n",
    "    u_yy = gradf(u, 1, 2) # 2nd order derivative of y\n",
    "\n",
    "    # Source q obtained by inserting the manufactured solution\n",
    "    # u = sin(a1*pi*x)*sin(a2*pi*y) into the operator. The last term must use\n",
    "    # k**2 (not k) so that it matches the (k**2)*u term of the residual below.\n",
    "    sines = jnp.sin(a1*jnp.pi*collocs[:,[0]])*jnp.sin(a2*jnp.pi*collocs[:,[1]])\n",
    "    source = -((a1*jnp.pi)**2)*sines - ((a2*jnp.pi)**2)*sines + (k**2)*sines\n",
    "\n",
    "    # Residual\n",
    "    pde_res = u_xx(collocs) + u_yy(collocs) + (k**2)*u(collocs) - source\n",
    "\n",
    "    return pde_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "215d1120-3dec-449c-90b1-1b08752546d7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ed3a7ae2-f96f-467c-ab23-63be4e5f717e",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88984314-05c3-4ac3-8ea2-d162666ebe8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "layer_dims = [2, 6, 6, 1]\n",
    "model = KAN(layer_dims=layer_dims, k=3, const_spl=False, const_res=False, add_bias=True, grid_e=0.05)\n",
    "variables = model.init(key, jnp.ones([1, 2]))\n",
    "\n",
    "# Define learning rates for scheduler\n",
    "lr_vals = {'init_lr': 0.001, 'scales': {0: 1.0}}\n",
    "\n",
    "# Epochs for grid adaptation. Adaptation is switched off for this run (empty schedule);\n",
    "# set use_grid_adapt = True to adapt every `adapt_every` epochs up to `adapt_stop`.\n",
    "use_grid_adapt = False\n",
    "adapt_every = 150\n",
    "adapt_stop = 5000\n",
    "if use_grid_adapt:\n",
    "    grid_adapt = [i * adapt_every for i in range(1, (adapt_stop // adapt_every) + 1)]\n",
    "else:\n",
    "    grid_adapt = []\n",
    "\n",
    "# Define epochs for grid extension, along with grid sizes\n",
    "grid_extend = {0 : 3}\n",
    "\n",
    "# Define global loss weights (five unit entries)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(5)]\n",
    "\n",
    "# RBA weights. They are switched off for this run (loc_w = None);\n",
    "# set use_rba = True to start from unit weights for every collocation set.\n",
    "use_rba = False\n",
    "if use_rba:\n",
    "    loc_w = [jnp.ones((pts.shape[0], 1)) for pts in [collocs, *bc_collocs]]\n",
    "else:\n",
    "    loc_w = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1bdd154-7e7d-4145-a4ac-a3db8af13917",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 50000\n",
    "\n",
    "# Train with the schedules and weights prepared in the previous cell\n",
    "model, variables, train_losses = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=False,\n",
    "    loc_w=loc_w,\n",
    "    nesterov=False,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend=grid_extend,\n",
    "    grid_adapt=grid_adapt,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c700c9c-ae5c-43b2-8a33-41a489aa93b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b03f4489-a8ef-4eda-aaf1-fd6a21bbd26b",
   "metadata": {},
   "source": [
    "### Plot & Save Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b539ed79-6ba7-4ec9-8443-bb773dd988bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_x, N_y = 100, 256\n",
    "\n",
    "# Evaluation grid: first axis follows x, second axis follows y\n",
    "x = np.linspace(-1.0, 1.0, N_x)\n",
    "y = np.linspace(-1.0, 1.0, N_y)\n",
    "X, Y = np.meshgrid(x, y, indexing='ij')\n",
    "coords = np.column_stack((X.ravel(), Y.ravel()))\n",
    "\n",
    "# Evaluate the trained model on the grid and restore the 2-D layout\n",
    "output, _ = model.apply(variables, jnp.array(coords))\n",
    "resplot = np.array(output).reshape(N_x, N_y)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "mesh = ax.pcolormesh(X, Y, resplot, shading='auto', cmap='Spectral_r')\n",
    "fig.colorbar(mesh, ax=ax)\n",
    "\n",
    "ax.set_title('Solution of Helmholtz Equation')\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('y')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac4c7a1e-b107-48a4-a18c-470c490f0fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss history on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(np.array(train_losses), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05c986eb-bdda-4fba-be2e-478e263f2c25",
   "metadata": {},
   "outputs": [],
   "source": [
    "def helm_exact(x, y):\n",
    "    \"\"\"Analytical solution u(x, y) = sin(a1*pi*x) * sin(a2*pi*y) with a1 = 1, a2 = 4.\"\"\"\n",
    "    a1 = 1.0\n",
    "    a2 = 4.0\n",
    "    return np.sin(a1*np.pi*x)*np.sin(a2*np.pi*y)\n",
    "\n",
    "# Reference solution on the same grid as the model prediction\n",
    "ref_result = helm_exact(X, Y)\n",
    "\n",
    "# Serialise the trained variables (raw bytes from flax.serialization.to_bytes)\n",
    "serialized_variables = flax.serialization.to_bytes(variables)\n",
    "\n",
    "# Create the output folders if needed so a finished training run is not lost\n",
    "# to a FileNotFoundError on a fresh checkout\n",
    "os.makedirs('jaxKAN models', exist_ok=True)\n",
    "os.makedirs('../Plots/data', exist_ok=True)\n",
    "\n",
    "with open('jaxKAN models/eq2-jaxkan.pkl', 'wb') as f:\n",
    "    f.write(serialized_variables)\n",
    "\n",
    "np.savez('../Plots/data/eq2-jaxkan.npz', x=x, y=y, result=resplot, ref=ref_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f730d81f-6188-4df8-a20e-4b99a4fa97e3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ccbccfe9-19ce-4ec8-9969-3d08829bad5e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Burgers' Equation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9bcfe13-7b1b-4be8-9bbb-d08c9ec579c6",
   "metadata": {},
   "source": [
    "### Collocation Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ee5a66f-9689-42ba-bfd4-94eaf44e5879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE-residual collocation points; column 0 is t and column 1 is x\n",
    "N = 2**12\n",
    "collocs = jnp.array(sobol_sample(np.array([0, -1]), np.array([1, 1]), N))  # (4096, 2)\n",
    "\n",
    "# Points and target values for the initial/boundary conditions\n",
    "N = 2**6\n",
    "\n",
    "# t = 0: u = -sin(pi*x)\n",
    "BC1_colloc = jnp.array(sobol_sample(np.array([0, -1]), np.array([0, 1]), N))  # (64, 2)\n",
    "BC1_data = -jnp.sin(np.pi * BC1_colloc[:, [1]])  # (64, 1)\n",
    "\n",
    "# x = -1: u = 0\n",
    "BC2_colloc = jnp.array(sobol_sample(np.array([0, -1]), np.array([1, -1]), N))  # (64, 2)\n",
    "BC2_data = jnp.zeros((BC2_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# x = 1: u = 0\n",
    "BC3_colloc = jnp.array(sobol_sample(np.array([0, 1]), np.array([1, 1]), N))  # (64, 2)\n",
    "BC3_data = jnp.zeros((BC3_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# Same ordering in both lists so points and targets stay paired\n",
    "bc_collocs = [BC1_colloc, BC2_colloc, BC3_colloc]\n",
    "bc_data = [BC1_data, BC2_data, BC3_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f1b656a-0e1a-45e4-9adf-2c5369dbcfeb",
   "metadata": {},
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb464672-377d-49e3-ace4-e6eb478a27ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_loss(params, collocs, state):\n",
    "    \"\"\"Pointwise residual of Burgers' equation u_t + u*u_x - v*u_xx = 0.\n",
    "\n",
    "    Args:\n",
    "        params: trainable parameters, passed to model.apply under 'params'.\n",
    "        collocs: collocation points; column 0 is t and column 1 is x.\n",
    "        state: non-trainable collection, passed to model.apply under 'state'.\n",
    "\n",
    "    Returns:\n",
    "        The PDE residual evaluated at every row of collocs.\n",
    "    \"\"\"\n",
    "    # Viscosity coefficient\n",
    "    v = jnp.array(0.01/jnp.pi, dtype=float)\n",
    "\n",
    "    variables = {'params' : params, 'state' : state}\n",
    "\n",
    "    def u(vec_x):\n",
    "        # model.apply returns a pair; only the first element (the prediction) is used\n",
    "        prediction, _spl = model.apply(variables, vec_x)\n",
    "        return prediction\n",
    "\n",
    "    # Derivative operators: (function, input column, derivative order)\n",
    "    u_t = gradf(u, 0, 1)\n",
    "    u_x = gradf(u, 1, 1)\n",
    "    u_xx = gradf(u, 1, 2)\n",
    "\n",
    "    # Assemble the residual from its advection and diffusion parts\n",
    "    advection = u(collocs)*u_x(collocs)\n",
    "    diffusion = v*u_xx(collocs)\n",
    "    pde_res = u_t(collocs) + advection - diffusion\n",
    "\n",
    "    return pde_res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "086c76cf-3ea2-446d-ab86-3d9d00af6f2d",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b604b5e9-34b5-428c-a5b5-9f89136ac3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed PRNG seed for parameter initialisation\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "# Build the KAN and initialise its variables with a dummy (1, 2) input\n",
    "layer_dims = [2, 6, 6, 1]\n",
    "model = KAN(\n",
    "    layer_dims=layer_dims,\n",
    "    k=3,\n",
    "    const_spl=False,\n",
    "    const_res=False,\n",
    "    add_bias=True,\n",
    "    grid_e=0.05,\n",
    ")\n",
    "variables = model.init(key, jnp.ones([1, 2]))\n",
    "\n",
    "# Learning-rate settings handed to the trainer's scheduler\n",
    "lr_vals = {'init_lr': 0.001, 'scales': {0: 1.0}}\n",
    "\n",
    "# Global loss weights: four unit entries (presumably PDE residual + one per boundary condition)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb32b6b5-de7c-4bd1-8b41-bb7dcbfef07a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 50000\n",
    "\n",
    "# Train with the settings defined above; the grid- and collocation-adaptation schedules are empty\n",
    "model, variables, train_losses = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=False,\n",
    "    loc_w=None,\n",
    "    nesterov=False,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend={0 : 3},\n",
    "    grid_adapt=[],\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30c65406-b054-4ecc-b30a-2ac5c77de7a6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d457bd37-090a-4951-be1c-334a48d84a13",
   "metadata": {},
   "source": [
    "### Plot & Save Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ca2af78-cfaa-4795-a6fe-c0da50aab8da",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_t, N_x = 100, 256\n",
    "\n",
    "# Evaluation grid: first axis follows t, second axis follows x\n",
    "t = np.linspace(0.0, 1.0, N_t)\n",
    "x = np.linspace(-1.0, 1.0, N_x)\n",
    "T, X = np.meshgrid(t, x, indexing='ij')\n",
    "coords = np.column_stack((T.ravel(), X.ravel()))\n",
    "\n",
    "# Evaluate the trained model on the grid and restore the 2-D layout\n",
    "output, _ = model.apply(variables, jnp.array(coords))\n",
    "resplot = np.array(output).reshape(N_t, N_x)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "mesh = ax.pcolormesh(T, X, resplot, shading='auto', cmap='Spectral_r')\n",
    "fig.colorbar(mesh, ax=ax)\n",
    "\n",
    "ax.set_title('Solution of Burgers Equation')\n",
    "ax.set_xlabel('t')\n",
    "ax.set_ylabel('x')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b600c570-70ed-4c11-89b8-15b21421dab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss history on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(np.array(train_losses), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6fa5c0-bf2b-4526-8f39-ae7fcdc91f2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the reference solution inside a context manager so the .npz archive is\n",
    "# closed again; only the 'usol' array is kept.\n",
    "with np.load('../External Data/Burgers.npz') as ref_archive:\n",
    "    ref = {'usol': ref_archive['usol']}\n",
    "\n",
    "# Serialise the trained variables (raw bytes from flax.serialization.to_bytes)\n",
    "serialized_variables = flax.serialization.to_bytes(variables)\n",
    "\n",
    "# Create the output folders if needed so a finished training run is not lost\n",
    "# to a FileNotFoundError on a fresh checkout\n",
    "os.makedirs('jaxKAN models', exist_ok=True)\n",
    "os.makedirs('../Plots/data', exist_ok=True)\n",
    "\n",
    "with open('jaxKAN models/eq3-jaxkan.pkl', 'wb') as f:\n",
    "    f.write(serialized_variables)\n",
    "\n",
    "# The reference array is transposed before saving, as in the original workflow\n",
    "np.savez('../Plots/data/eq3-jaxkan.npz', t=t, x=x, result=resplot, ref=ref['usol'].T)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "907f82a3-fe52-4386-bfae-4dc79f417076",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Allen-Cahn Equation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7926820-4109-4554-9a87-6d76a7d9a5c8",
   "metadata": {},
   "source": [
    "### Collocation Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bc53dd0-1d06-4d30-a490-fe074d59e373",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE-residual collocation points; column 0 is t and column 1 is x\n",
    "N = 2**12\n",
    "collocs = jnp.array(sobol_sample(np.array([0, -1]), np.array([1, 1]), N))  # (4096, 2)\n",
    "\n",
    "# Points and target values for the initial/boundary conditions\n",
    "N = 2**6\n",
    "\n",
    "# t = 0: u = x**2 * cos(pi*x)\n",
    "BC1_colloc = jnp.array(sobol_sample(np.array([0, -1]), np.array([0, 1]), N))  # (64, 2)\n",
    "BC1_data = (BC1_colloc[:, [1]]**2) * jnp.cos(jnp.pi * BC1_colloc[:, [1]])  # (64, 1)\n",
    "\n",
    "# x = -1: u = -1\n",
    "BC2_colloc = jnp.array(sobol_sample(np.array([0, -1]), np.array([1, -1]), N))  # (64, 2)\n",
    "BC2_data = -jnp.ones((BC2_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# x = 1: u = -1\n",
    "BC3_colloc = jnp.array(sobol_sample(np.array([0, 1]), np.array([1, 1]), N))  # (64, 2)\n",
    "BC3_data = -jnp.ones((BC3_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# Same ordering in both lists so points and targets stay paired\n",
    "bc_collocs = [BC1_colloc, BC2_colloc, BC3_colloc]\n",
    "bc_data = [BC1_data, BC2_data, BC3_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01a276d9-4391-41a5-bb8b-27014f5e6133",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "93f0a506-3caf-4576-9318-92079a84af4c",
   "metadata": {},
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64b26a49-82bb-4906-8417-7ced56f5f28a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_loss(params, collocs, state):\n",
    "    \"\"\"Pointwise residual of the Allen-Cahn equation u_t - D*u_xx - c*(u - u^3) = 0.\n",
    "\n",
    "    Args:\n",
    "        params: trainable parameters, passed to model.apply under 'params'.\n",
    "        collocs: collocation points; column 0 is t and column 1 is x.\n",
    "        state: non-trainable collection, passed to model.apply under 'state'.\n",
    "\n",
    "    Returns:\n",
    "        The PDE residual evaluated at every row of collocs.\n",
    "    \"\"\"\n",
    "    # Diffusion and reaction coefficients\n",
    "    D = jnp.array(0.001, dtype=float)\n",
    "    c = jnp.array(5.0, dtype=float)\n",
    "\n",
    "    variables = {'params' : params, 'state' : state}\n",
    "\n",
    "    def u(vec_x):\n",
    "        # model.apply returns a pair; only the first element (the prediction) is used\n",
    "        prediction, _spl = model.apply(variables, vec_x)\n",
    "        return prediction\n",
    "\n",
    "    # Derivative operators: (function, input column, derivative order)\n",
    "    u_t = gradf(u, 0, 1)\n",
    "    u_xx = gradf(u, 1, 2)\n",
    "\n",
    "    # Evaluate the prediction once and reuse it in the cubic reaction term\n",
    "    u_val = u(collocs)\n",
    "    reaction = c*(u_val - (u_val**3))\n",
    "    pde_res = u_t(collocs) - D*u_xx(collocs) - reaction\n",
    "\n",
    "    return pde_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c16934b-5e2b-413b-863c-3713e9aa5397",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8364c526-e5e2-4198-86cb-45518a9b1c4f",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06400699-f7e0-4c17-b68f-c07878cda794",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "layer_dims = [2, 6, 6, 1]\n",
    "model = KAN(layer_dims=layer_dims, k=3, const_spl=False, const_res=False, add_bias=True, grid_e=0.05)\n",
    "variables = model.init(key, jnp.ones([1, 2]))\n",
    "\n",
    "# Define learning rates for scheduler\n",
    "lr_vals = {'init_lr': 0.001, 'scales': {0: 1.0}}\n",
    "\n",
    "# Epochs for grid adaptation. Adaptation is switched off for this run (empty schedule);\n",
    "# set use_grid_adapt = True to adapt every `adapt_every` epochs up to `adapt_stop`.\n",
    "use_grid_adapt = False\n",
    "adapt_every = 150\n",
    "adapt_stop = 5000\n",
    "if use_grid_adapt:\n",
    "    grid_adapt = [i * adapt_every for i in range(1, (adapt_stop // adapt_every) + 1)]\n",
    "else:\n",
    "    grid_adapt = []\n",
    "\n",
    "# Define epochs for grid extension, along with grid sizes\n",
    "grid_extend = {0 : 3}\n",
    "\n",
    "# Define global loss weights (four unit entries)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(4)]\n",
    "\n",
    "# RBA weights. They are switched off for this run (loc_w = None);\n",
    "# set use_rba = True to start from unit weights for every collocation set.\n",
    "use_rba = False\n",
    "if use_rba:\n",
    "    loc_w = [jnp.ones((pts.shape[0], 1)) for pts in [collocs, *bc_collocs]]\n",
    "else:\n",
    "    loc_w = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7f6dddc-f6c7-4926-969b-08a053ffdde4",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 50000\n",
    "\n",
    "# Train with the schedules and weights prepared in the previous cell\n",
    "model, variables, train_losses = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=False,\n",
    "    loc_w=loc_w,\n",
    "    nesterov=False,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend=grid_extend,\n",
    "    grid_adapt=grid_adapt,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef092575-a73c-432a-98c4-b762bde07d06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "03aa14a3-81f0-4a66-a10d-b1f968370a0f",
   "metadata": {},
   "source": [
    "### Plot & Save Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b057142e-440d-4fa8-ab23-6a5da58f701b",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_t, N_x = 101, 201\n",
    "\n",
    "# Evaluation grid: first axis follows t, second axis follows x\n",
    "t = np.linspace(0.0, 1.0, N_t)\n",
    "x = np.linspace(-1.0, 1.0, N_x)\n",
    "T, X = np.meshgrid(t, x, indexing='ij')\n",
    "coords = np.column_stack((T.ravel(), X.ravel()))\n",
    "\n",
    "# Evaluate the trained model on the grid and restore the 2-D layout\n",
    "output, _ = model.apply(variables, jnp.array(coords))\n",
    "resplot = np.array(output).reshape(N_t, N_x)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "mesh = ax.pcolormesh(T, X, resplot, shading='auto', cmap='Spectral_r')\n",
    "fig.colorbar(mesh, ax=ax)\n",
    "\n",
    "ax.set_title('Solution of Allen-Cahn Equation')\n",
    "ax.set_xlabel('t')\n",
    "ax.set_ylabel('x')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4aaf0f2-0f66-483c-a3d7-a0ba8f6e7ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss history on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(np.array(train_losses), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a728300f-b441-4c1b-b0a4-16f33f08c031",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the submodule explicitly: a bare `import scipy` does not guarantee\n",
    "# that `scipy.io` is available as an attribute on every SciPy version.\n",
    "import scipy.io\n",
    "\n",
    "# Reference solution from the external MATLAB file\n",
    "ref = scipy.io.loadmat('../External Data/usol_D_0.001_k_5.mat')\n",
    "\n",
    "# Serialise the trained variables (raw bytes from flax.serialization.to_bytes)\n",
    "serialized_variables = flax.serialization.to_bytes(variables)\n",
    "\n",
    "# Create the output folders if needed so a finished training run is not lost\n",
    "# to a FileNotFoundError on a fresh checkout\n",
    "os.makedirs('jaxKAN models', exist_ok=True)\n",
    "os.makedirs('../Plots/data', exist_ok=True)\n",
    "\n",
    "with open('jaxKAN models/eq4-jaxkan.pkl', 'wb') as f:\n",
    "    f.write(serialized_variables)\n",
    "\n",
    "np.savez('../Plots/data/eq4-jaxkan.npz', t=t, x=x, result=resplot, ref=ref['u'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c0c67f-0186-43f8-9b6a-46794e2120e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
